library(tidyverse) # for data wrangling etc
library(cmdstanr) # for cmdstan
library(brms) # for fitting models in Stan
library(standist) # for exploring distributions
library(coda) # for diagnostics
library(bayesplot) # for diagnostics
library(ggmcmc) # for MCMC diagnostics
library(DHARMa) # for residual diagnostics
library(rstan) # for interfacing with Stan
library(emmeans) # for marginal means etc
library(broom) # for tidying outputs
library(tidybayes) # for more tidying outputs
library(HDInterval) # for HPD intervals
library(ggeffects) # for partial plots
library(broom.mixed) # for summarising models
library(posterior) # for posterior draws
library(patchwork) # for multi-panel figures
library(bayestestR) # for ROPE
library(see) # for some plots
library(easystats) # framework for stats, modelling and visualisation
library(mgcv)
library(gratia)
theme_set(theme_grey()) # put the default ggplot theme back
source("helperFunctions.R")
Bayesian GAM Part 1
1 Preparations
Load the necessary libraries
2 Scenario
This is an entirely fabricated example (how embarrassing). So here is a picture of some Red Grouse chicks to compensate.
| x | y |
|---|---|
| 2 | 3 |
| 4 | 5 |
| 8 | 6 |
| 10 | 7 |
| 14 | 4 |
| Variable | Description |
|---|---|
| x | a continuous predictor |
| y | a continuous response |
3 Read in the data
Rows: 5 Columns: 2
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (2): x, y
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
data_gam (5 rows and 2 variables, 2 shown)
ID | Name | Type | Missings | Values | N
---+------+---------+----------+--------+----------
1 | x | numeric | 0 (0.0%) | 2 | 1 (20.0%)
| | | | 4 | 1 (20.0%)
| | | | 8 | 1 (20.0%)
| | | | 10 | 1 (20.0%)
| | | | 14 | 1 (20.0%)
---+------+---------+----------+--------+----------
2 | y | numeric | 0 (0.0%) | 3 | 1 (20.0%)
| | | | 4 | 1 (20.0%)
| | | | 5 | 1 (20.0%)
| | | | 6 | 1 (20.0%)
| | | | 7 | 1 (20.0%)
---------------------------------------------------
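The readr output above implies the data were loaded from a CSV; if that file is not to hand, the five observations tabulated in the Scenario section can be entered directly (a minimal reconstruction):

```r
library(tibble)
# reconstruct the data from the table in the Scenario section
data_gam <- tibble(
  x = c(2, 4, 8, 10, 14),
  y = c(3, 5, 6, 7, 4)
)
data_gam
```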
4 Exploratory data analysis
Model formula: \[ \begin{align} y_i &\sim{} \mathcal{N}(\mu_i, \sigma^2)\\ \mu_i &=\beta_0 + f(x_i)\\ f(x_i) &= \sum^k_{j=1}{b_j(x_i)\beta_j} \end{align} \]
where \(\beta_0\) is the y-intercept, and \(f(x)\) indicates an additive smoothing function of \(x\).
Although this is a fictitious example without a clear backstory, given that there are two continuous variables (and that one of these has been identified as the response and the other as the predictor), we can assume that we are interested in investigating the relationship between the two. As such, our typical starting point is to explore the basic trend between the two using a scatterplot.
This does not look like a particularly linear relationship. Let's fit a loess smoother.
And what would a linear smoother look like?
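Both smoothers can be produced with `geom_smooth()` using the data_gam data read in earlier (a sketch; with only five points the loess fit will be very approximate):

```r
library(ggplot2)
# raw scatterplot with a loess smoother
ggplot(data_gam, aes(y = y, x = x)) +
  geom_point() +
  geom_smooth(method = "loess")
# the same data with a linear (lm) smoother for comparison
ggplot(data_gam, aes(y = y, x = x)) +
  geom_point() +
  geom_smooth(method = "lm")
```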
Rather than either a loess or linear smoother, we can also try a Generalized Additive Model (GAM) smoother. Don't pay too much attention to the GAM formula at this stage; it will be discussed later in the Model Fitting section.
ggplot(data_gam, aes(y = y, x = x)) +
geom_point() +
  geom_smooth(method = "gam", formula = y ~ s(x, k = 3))
Conclusions:
- it is clear that the relationship is not linear.
- it does appear that as x increases, y initially increases before eventually declining again.
- we could model this with a polynomial, but for this example, we will use these data to illustrate the fitting of GAMs.
5 Fit the model
Prior to fitting the GAM, it might be worth gaining a bit of an understanding of what will occur behind the scenes.
Let's say we intended to fit a smoother with three knots: one at each end of the trend and one in the middle. We could re-express our predictor (x) as three dummy variables that collectively reflect a spline (in this case, potentially two joined polynomials).
X1 X2 X3 x y
1 1.3554342 1 -1.31122014 2 3
2 0.9289363 1 -0.84292723 4 5
3 0.4755086 1 0.09365858 8 6
4 0.5780165 1 0.56195149 10 7
5 1.3189632 1 1.49853730 14 4
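Basis columns like those tabulated above can be generated with `mgcv::smoothCon()`, which constructs the design matrix for a smooth term. A sketch (the exact values depend on the basis type and parameterisation, so they need not match the table above exactly):

```r
library(mgcv)
# the data from the Scenario section
data_gam <- data.frame(x = c(2, 4, 8, 10, 14), y = c(3, 5, 6, 7, 4))
# build a thin-plate regression spline basis with k = 3 basis functions
sm <- smoothCon(s(x, k = 3), data = data_gam)[[1]]
# a 5 x 3 matrix: each column is one basis function evaluated at the observed x
sm$X
```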
And we could visualize these dummies as three separate components.
OR
brms follows the same basic process as gamm4. That is, the smooths are partitioned into two components:
- a penalised component, which is treated as a random effect
- an unpenalised component, which is treated as a fixed effect
The wiggliness penalty matrix is the precision matrix when the smooth is treated as a random effect. The smoothness of a term is determined by estimating the variance of that term.
In brms, the default priors are designed to be weakly informative. They are chosen to provide moderate regularisation (to help prevent over-fitting) and help stabilise the computations.
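These default priors can be listed before fitting anything with `get_prior()` (a sketch, using the k = 3 formula that these data require):

```r
library(brms)
# the data from the Scenario section
data_gam <- data.frame(x = c(2, 4, 8, 10, 14), y = c(3, 5, 6, 7, 4))
# show the default prior for every parameter class in the intended model
get_prior(bf(y ~ s(x, k = 3)), data = data_gam, family = gaussian())
```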
Unlike rstanarm, brms models must be compiled before they can start sampling. For most models, compilation of the Stan code takes around 45 seconds.
data_gam.form <- bf(y ~ s(x), family = gaussian())
data_gam.brm <- brm(data_gam.form,
data = data_gam,
iter = 5000,
warmup = 1000,
chains = 3,
thin = 5,
refresh = 0,
backend = "rstan"
)
Error in smooth.construct.tp.smooth.spec(object, dk$data, dk$knots): A term has fewer unique covariate combinations than specified maximum degrees of freedom
The model fails because the default basis dimension for s(x) (k = 10) exceeds the number of unique x values (5); the basis dimension must therefore be capped (k = 3) before the model will fit.
                 prior     class        coef group resp dpar nlpar lb ub       source
                (flat)         b                                                default
                (flat)         b        sx_1                               (vectorized)
  student_t(3, 5, 2.5) Intercept                                                default
  student_t(3, 0, 2.5)       sds                                    0           default
  student_t(3, 0, 2.5)       sds s(x, k = 3)                        0      (vectorized)
  student_t(3, 0, 2.5)     sigma                                    0           default
data_gam.brm <- brm(data_gam.form,
data = data_gam,
iter = 5000,
warmup = 1000,
chains = 3, cores = 3,
thin = 5,
refresh = 0,
backend = "rstan"
)
Compiling Stan program...
Start sampling
Warning: There were 43 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
Warning: Examine the pairs() plot to diagnose sampling problems
Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#tail-ess
prior class coef group resp dpar nlpar lb ub source
(flat) b default
(flat) b sx_1 (vectorized)
student_t(3, 5, 2.5) Intercept default
student_t(3, 0, 2.5) sds 0 default
student_t(3, 0, 2.5) sds s(x, k = 3) 0 (vectorized)
student_t(3, 0, 2.5) sigma 0 default
sds - standard deviation of the wiggly basis functions
data_gam.brm <- brm(data_gam.form,
data = data_gam,
prior = prior(normal(0, 2.5), class = "b"),
sample_prior = "only",
iter = 5000,
warmup = 1000,
chains = 3,
thin = 5,
backend = "rstan",
refresh = 0
)
Compiling Stan program...
Start sampling
conditional_effects
The following link provides some guidance about defining priors: <https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations>
When defining our own priors, we typically do not want them to be scaled.
If we wanted to define our own priors that were less vague, yet still not likely to bias the outcomes, we could try the following priors (which I have mainly plucked out of thin air):
- \(\beta_0\): normal centred at 164 with a standard deviation of 65
  - mean of 164: since median(fert$YIELD)
  - sd of 65: since mad(fert$YIELD)
- \(\beta_1\): normal centred at 0 with a standard deviation of 2.5
  - sd of 2.5: since 2.5 * (mad(fert$YIELD) / mad(fert$FERTILIZER))
- \(\sigma\): half-t centred at 0 with a standard deviation of 65
  - sd of 65: since mad(fert$YIELD)
- OR \(\sigma\): gamma with shape parameters of 2 and 1
Sample prior only
I will also overlay the raw data for comparison.
# A tibble: 1 × 2
`median(y)` `mad(y)`
<dbl> <dbl>
1 5 1.48
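The tibble above is just the median and MAD of the response, which are used below to centre and scale the priors. A base-R sketch:

```r
# median and median absolute deviation of y from the Scenario table
y <- c(3, 5, 6, 7, 4)
median(y) # 5
mad(y)    # 1.4826 (median absolute deviation of 1, scaled by the default constant)
```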
priors <- prior(normal(5, 1.5), class = "Intercept") +
prior(normal(0, 1.5), class = "b") +
prior(student_t(3, 0, 1.5), class = "sigma") +
prior(student_t(3, 0, 10), class = "sds")
data_gam.form <- bf(y ~ s(x, k = 3))
data_gam.brm2 <- brm(data_gam.form,
data = data_gam,
prior = priors,
sample_prior = "only",
iter = 5000,
warmup = 1000,
chains = 3, cores = 3,
thin = 5,
backend = "rstan",
control = list(adapt_delta = 0.99),
refresh = 0
)
Compiling Stan program...
Start sampling
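To see what these priors imply on the response scale, the prior-only draws can be visualised with `conditional_effects()`, overlaying the raw data for comparison (a sketch, assuming data_gam.brm2 sampled successfully):

```r
library(brms)
# plot the prior-only predictions of y against x, with the observed
# points overlaid for comparison
data_gam.brm2 |>
  conditional_effects() |>
  plot(points = TRUE)
```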
Sample prior and posterior
The desired updates require recompiling the model
Compiling Stan program...
Start sampling
Warning: There were 1 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
Warning: Examine the pairs() plot to diagnose sampling problems
[1] "b_Intercept" "bs_sx_1" "sds_sx_1" "sigma"
[5] "Intercept" "s_sx_1[1]" "prior_Intercept" "prior_bs"
[9] "prior_sds_sx" "prior_sigma" "lprior" "lp__"
[13] "accept_stat__" "stepsize__" "treedepth__" "n_leapfrog__"
[17] "divergent__" "energy__"
Error in model.matrix.default(f, dat): model frame and formula mismatch in model.matrix()
6 MCMC sampling diagnostics
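Typical sampling diagnostics at this stage include trace and rank plots of the key parameters (a sketch, assuming data_gam.brm3 is the final model fit above):

```r
library(brms)
# trace plots: chains should mix well and overlap
mcmc_plot(data_gam.brm3, type = "trace",
  variable = "^b_|^bs_|^sds|^sigma$", regex = TRUE)
# rank-overlay plots: roughly uniform histograms indicate good mixing
mcmc_plot(data_gam.brm3, type = "rank_overlay",
  variable = "^b_|^bs_|^sds|^sigma$", regex = TRUE)
```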
7 Model validation
data_gam.resids <- make_brms_dharma_res(data_gam.brm3, integerResponse = FALSE)
wrap_elements(~ testUniformity(data_gam.resids)) +
wrap_elements(~ plotResiduals(data_gam.resids, form = factor(rep(1, nrow(data_gam))))) +
wrap_elements(~ plotResiduals(data_gam.resids, quantreg = FALSE)) +
  wrap_elements(~ testDispersion(data_gam.resids))
Warning in smooth.spline(pred, res, df = 10): not using invalid df; must have 1
< df <= n := #{unique x} = 5
DHARMa nonparametric dispersion test via sd of residuals fitted vs.
simulated
data: simulationOutput
dispersion = 0.11431, p-value = 0.25
alternative hypothesis: two.sided
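In addition to the DHARMa residuals, a posterior predictive check compares the observed distribution of y against distributions simulated from the posterior (a sketch):

```r
library(brms)
# overlay densities of 100 posterior-predictive draws on the observed density
pp_check(data_gam.brm3, type = "dens_overlay", ndraws = 100)
```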
8 Partial effects plots
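The partial effect of x can be plotted directly from the fitted model with `conditional_effects()`, again overlaying the raw data (a sketch):

```r
library(brms)
# conditional (partial) effect of the smooth of x, with observed points
data_gam.brm3 |>
  conditional_effects(effects = "x") |>
  plot(points = TRUE)
```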
9 Model investigation
Warning: There were 1 divergent transitions after warmup. Increasing
adapt_delta above 0.99 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
Family: gaussian
Links: mu = identity; sigma = identity
Formula: y ~ s(x, k = 3)
Data: data_gam (Number of observations: 5)
Draws: 3 chains, each with iter = 5000; warmup = 1000; thin = 5;
total post-warmup draws = 2400
Smoothing Spline Hyperparameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sds(sx_1) 12.84 9.57 1.09 37.47 1.00 1670 1303
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 5.01 0.49 3.97 6.06 1.00 2292 2049
sx_1 0.30 0.48 -0.75 1.27 1.00 2486 2179
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 1.10 0.62 0.39 2.66 1.00 1366 2137
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
sds(sx_1) is the sd of the smooth weights (spline coefficients). This determines the amount of ‘wiggliness’, analogous to how the sd of group-level effects in a varying slopes and intercepts model determines the amount of variability among groups in slopes and intercepts. The actual numeric value of the sds() is not very practically interpretable, because thinking about the variance of smooth weights for any given data and model seems abstract to me. However, if the value is around zero, then this is like ‘complete pooling’ of the basis functions, which means that there is not much added value in having more than a single basis function.
sx_1 is the unpenalized weight (i.e. coefficient) for one of the “natural” parameterized basis functions. The rest of the basis functions are treated like varying effects. Again, because the actual numeric value of sx_1 is the unpenalized coefficient for one of the basis functions, it does not have a lot of practically interpretable meaning just from viewing this number.
[1] "b_Intercept" "bs_sx_1" "sds_sx_1" "sigma"
[5] "Intercept" "s_sx_1[1]" "prior_Intercept" "prior_bs"
[9] "prior_sds_sx" "prior_sigma" "lprior" "lp__"
[13] "accept_stat__" "stepsize__" "treedepth__" "n_leapfrog__"
[17] "divergent__" "energy__"
data_gam.brm3 |>
as_draws_df() |>
dplyr::select(matches("^b_.*|^bs.*|^sds.*|^sigma$|^s_s.*")) |>
summarise_draws(median,
HDInterval::hdi,
Pl = ~ mean(.x < 0),
Pg = ~ mean(.x > 0)
  )
Warning: Dropping 'draws_df' class as required metadata was removed.
# A tibble: 5 × 6
variable median lower upper Pl Pg
<chr> <dbl> <dbl> <dbl> <dbl> <dbl>
1 b_Intercept 5.00 3.92 5.96 0 1
2 bs_sx_1 0.321 -0.819 1.18 0.211 0.789
3 sds_sx_1 10.9 0.00208 30.0 0 1
4 sigma 0.932 0.301 2.31 0 1
5 s_sx_1[1] 10.9 -2.07 19.1 0.0633 0.937
10 Further analyses
newdata <- with(data_gam, data.frame(x = c(min(x), 9)))
add_epred_draws(
object = data_gam.brm3, newdata = newdata,
ndraws = 2400
) |>
ungroup() |>
group_by(.draw) |>
summarise(Diff = diff(.epred)) |>
summarise(median_hdci(Diff),
Pl = mean(Diff < 0),
Pg = mean(Diff > 0)
  )
# A tibble: 1 × 8
y ymin ymax .width .point .interval Pl Pg
<dbl> <dbl> <dbl> <dbl> <chr> <chr> <dbl> <dbl>
1 3.02 -0.661 5.20 0.95 median hdci 0.0525 0.948
newdata <- with(data_gam, data.frame(x = seq(min(x), max(x), length = 1000)))
data_gam.peak <-
add_epred_draws(object = data_gam.brm3, newdata = newdata, ndraws = 1000) |>
ungroup() |>
group_by(.draw) |>
# summarise(x = x[which.max(.epred)]) |>
mutate(diff = .epred - lag(.epred)) |>
summarise(x = x[which.min(abs(diff))]) |>
median_hdci(x, .width = 0.95)
data_gam.peak
# A tibble: 1 × 6
x .lower .upper .width .point .interval
<dbl> <dbl> <dbl> <dbl> <chr> <chr>
1 8.52 4.02 14 0.95 median hdci
## let's plot this
data_gam.preds <-
data_gam.brm3 |>
add_epred_draws(newdata = newdata, object = _) |>
ungroup() |>
dplyr::select(-.row, -.chain, -.iteration) |>
group_by(x) |>
summarise_draws(median, HDInterval::hdi) |>
ungroup() |>
mutate(
Flag = between(x, data_gam.peak$.lower, data_gam.peak$.upper),
Grp = data.table::rleid(Flag)
)
data_gam.preds |> head()
# A tibble: 6 × 7
x variable median lower upper Flag Grp
<dbl> <chr> <dbl> <dbl> <dbl> <lgl> <int>
1 2 .epred 3.36 1.66 5.73 FALSE 1
2 2.01 .epred 3.37 1.67 5.73 FALSE 1
3 2.02 .epred 3.37 1.68 5.73 FALSE 1
4 2.04 .epred 3.38 1.68 5.72 FALSE 1
5 2.05 .epred 3.39 1.69 5.72 FALSE 1
6 2.06 .epred 3.40 1.70 5.72 FALSE 1
ggplot(data_gam.preds, aes(y = median, x = x)) +
geom_line(aes(colour = Flag, group = Grp)) +
  geom_ribbon(aes(ymin = lower, ymax = upper, fill = Flag, group = Grp), alpha = 0.2)
Unfortunately, it does not appear that this option provides confidence intervals.
newdata <- with(data_gam, data.frame(x = seq(min(x), max(x), length = 1000)))
data_gam.brm3 |>
add_epred_draws(newdata = newdata, object = _) |>
ungroup() |>
group_by(.draw) |>
mutate(diff = .epred - lag(.epred)) |>
summarise(
maxGrad = max(abs(diff), na.rm = TRUE),
x = x[which.max(diff)]
) |>
  summarise_draws(median, HDInterval::hdi)
# A tibble: 2 × 4
variable median lower upper
<chr> <dbl> <dbl> <dbl>
1 maxGrad 0.00884 0.000737 0.0144
2 x 2.01 2.01 14
newdata <- with(data_gam, data.frame(x = seq(min(x), max(x), length = 1000)))
data_gam.brm3 |>
add_epred_draws(newdata = newdata, object = _) |>
filter(x > 3, x < 13) |>
ungroup() |>
group_by(.draw) |>
mutate(
diff = .epred - lag(.epred),
diff2 = diff - lag(diff)
) |>
summarise(
maxChange = max(abs(diff2), na.rm = TRUE),
x = x[which.max(diff)]
) |>
  summarise_draws(median, HDInterval::hdi)
# A tibble: 2 × 4
variable median lower upper
<chr> <dbl> <dbl> <dbl>
1 maxChange 0.0000260 0.00000000545 0.0000430
2 x 3.02 3.02 13.0